#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
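Created on Sat Feb 17 11:44:57 2024

@author: jcfq2

Build the Saturn JWST/NIRSpec auroral figure saved as asd_temper_fig1.pdf:
a north-polar map of normalized H3+ brightness plus, at 80, 70 and 60 N,
the raw spectrum, the PSG-background windows, the non-LTE CH4 background
fit, and the h3ppy H3+ fit to the residual.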
"""

import os

# Hard-coded data root. The chdir below makes the relative np.load paths and
# the savefig output later in this script resolve inside this directory.
jwst_dir='/Users/jcfq2/data/observations/jwst'

os.chdir(jwst_dir)
import matplotlib.colors as colours
import matplotlib.patches as patches
import matplotlib.pyplot as plt
# from astropy.visualization import make_lupton_rgb
# from astropy.table import Table
import ch4_fiddlesticks as ch4
# import pandas as pd
import h3ppy
import glob
# import spiceypy as spice
# from astropy.io import fits
import numpy as np
from importlib import reload
import JWSTSolarSystemPointing as jssp
# Re-import so edits to the local pointing module are picked up when the
# script is re-run in the same interpreter session.
reload(jssp)
#import jwst_uranus as jwstu
from scipy.interpolate import interp1d

# In[0]

# Kernel directory handed to the pointing module (presumably SPICE kernels
# used for the observation geometry — confirm in JWSTSolarSystemPointing).
kernel_dir = '/Users/jcfq2/data/observations/jwst/kernels/'
jssp.load_kernels(kdir=kernel_dir)

# Set up an h3ppy object too, always useful.
# NOTE(review): h3p_model is rebound to h3p.model() spectra later in this
# script, so this instance is not used afterwards.
h3p_model = h3ppy.h3p()




# =============================================================================
# In[0] Set up a color table
# =============================================================================

# Sample the lower 83% of matplotlib's 'hot' colormap at 256 points and
# rebuild those RGBA rows as a standalone colormap named 'hot_crop'
# (presumably to drop the palest top end of 'hot').
# NOTE(review): the map panel below uses cmap='hot', so hot_crop is not
# referenced elsewhere in this file.
n_colours = 256
colors1 = plt.cm.hot(np.linspace(0.0, 0.83, num=n_colours))
hot_crop = colours.LinearSegmentedColormap.from_list(name='hot_crop', colors=colors1)





# In[1]


import cartopy.crs as ccrs
# Spherical globe (flattening 0) rotated-pole CRS.
# NOTE(review): `crs` is reassigned to a NorthPolarStereo CRS further down,
# and neither value is passed to a plotting call in this file.
crs = ccrs.RotatedPole(globe=ccrs.Globe(flattening=(0.0)))

 
# Dithered NIRSpec cubes; the first one only provides a pointing object for
# the wavelength grid and unit conversion.
files = sorted(glob.glob(
    "/Users/jcfq2/data/observations/jwst/5308_dither_separated/*.fits"))
geo = jssp.JWSTSolarSystemPointing(files[0])
wave = geo.get_wavelength()
#  just to get geo for conversion

# Pre-built maps indexed [row, col, channel], loaded relative to jwst_dir.
# nam appears to be an accumulated spectral sum and nac the matching count,
# so nam/nac is a mean spectrum — TODO confirm against the map builder.
nam=np.load('saturn_spectra_map.npy')
nac=np.load('saturn_spectra_count.npy')
# nam=np.load('saturn_spectra_map_2_los.npy')
# nac=np.load('saturn_spectra_count_2_los.npy')
# nam=np.load('saturn_spectra_map_2_narrow.npy')
# nac=np.load('saturn_spectra_count_2_narrow.npy')

# Bins per degree (multiplies every hard-coded map index below).
scale=1

# Wrap-around folding along the first (longitude-like) axis. The map panel
# keeps rows 45..405, so the first axis presumably has 450*scale rows: 360
# kept plus a 45-row pad at each end. Each pad is added onto the kept rows
# 360 bins away, and nac gets the same treatment so nam/nac stays a mean.
nam[-90*scale:-45*scale,:,:]=nam[-90*scale:-45*scale,:,:]+nam[0:45*scale,:,:]
nac[-90*scale:-45*scale,:,:]=nac[-90*scale:-45*scale,:,:]+nac[0:45*scale,:,:]
nam[45*scale:90*scale,:,:]=nam[45*scale:90*scale,:,:]+nam[-45*scale:,:,:]
nac[45*scale:90*scale,:,:]=nac[45*scale:90*scale,:,:]+nac[-45*scale:,:,:]


# Channel-600 mean image. NOTE(review): sat_T is not used later in this file.
sat_T=nam[:,:,600]/nac[:,:,600]


# Hand-picked channel sums. ab looks like a sum over emission-line channels
# and cd like a weighted estimate of the neighbouring background for those
# lines (assumption — the channel choices and weights are not documented).
ab=nam[:,:,2239]+nam[:,:,1568]+nam[:,:,1642]+nam[:,:,1010]+nam[:,:,1011]+nam[:,:,1018]+nam[:,:,1019]
cd=(nam[:,:,2241]*0.8*0.5+nam[:,:,2236]*1.19*0.5)+((nam[:,:,1571]+nam[:,:,1565])*0.5)+(0.4*nam[:,:,1640]/5.21)+(nam[:,:,1004]+nam[:,:,1005]*2)


# Line-minus-background brightness per cell, divided by the channel-600
# count (presumably the same count applies to every channel).
img_h3p=(ab-cd)/nac[:,:,600]

# Set a wavelength range where the non-LTE CH4 is.
# NOTE(review): whw is only used by commented-out code further down.
whw = np.argwhere((wave > 3) & (wave < 4.6)).flatten()
print(whw)
# Fit the non-LTE CH4 background (done later, per latitude, inside the loop).
# crs = ccrs.RotatedPole(globe=ccrs.Globe(flattening=(0.0)))
crs = ccrs.NorthPolarStereo(globe=ccrs.Globe(flattening=(0.0)))



# Normalise the kept 360-row range to its peak. The transpose + fliplr
# reorients it for imshow(extent=[-180,180,-90,90]) in the map panel.
maxint=np.nanmax(img_h3p[45*scale:405*scale,:].T)

map_h3p=np.fliplr(img_h3p[45*scale:405*scale,:].T)/maxint



# Mean spectrum per map cell, multiplied by 1000 (a unit rescale — TODO
# confirm which units). Cells with nac == 0 divide to NaN or inf;
# nan_to_num turns NaN into 0 and +/-inf into large finite values.
spec_cube=nam/nac*1000
spec_cube=np.nan_to_num(spec_cube)


# Model background radiance spectra, one text file per wavelength window
# (presumably exported from NASA's Planetary Spectrum Generator, given the
# "PSG_files"/"psg_rad" names). Column 0 is later used as wavelength and
# column 1 as radiance.

# could be problematic - lots going on in the background 
file_331='/Users/jcfq2/data/observations/jwst/saturn/PSG_files/psg_rad_331.txt'
bck_331 = np.loadtxt(file_331)


# could be problematic - lots going on in the background 
file_345='/Users/jcfq2/data/observations/jwst/saturn/PSG_files/psg_rad_345.txt'
bck_345 = np.loadtxt(file_345)



# could be problematic - lots going on in the background 
file_353='/Users/jcfq2/data/observations/jwst/saturn/PSG_files/psg_rad_353.txt'
bck_353 = np.loadtxt(file_353)


file_360='/Users/jcfq2/data/observations/jwst/saturn/PSG_files/psg_rad_360_363.txt'
bck_360 = np.loadtxt(file_360)



file_366='/Users/jcfq2/data/observations/jwst/saturn/PSG_files/psg_rad_366_3675.txt'
bck_366 = np.loadtxt(file_366)


#

file_371_372='/Users/jcfq2/data/observations/jwst/saturn/PSG_files/psg_rad_371_372.txt'
bck_371_372 = np.loadtxt(file_371_372)


# horrid background
# file_388_389='/Users/jcfq2/data/observations/jwst/saturn/PSG_files/psg_rad_388_389.txt'
# bck_388_389 = np.loadtxt(file_388_389)

file_390_391='/Users/jcfq2/data/observations/jwst/saturn/PSG_files/psg_rad_390_391.txt'
bck_390_391 = np.loadtxt(file_390_391)


file_392_3935='/Users/jcfq2/data/observations/jwst/saturn/PSG_files/psg_rad_392_3935.txt'
bck_392_3935 = np.loadtxt(file_392_3935)

# NOTE(review): the variable says 393_3958 but the file name covers 3930-3959.
file_393_3958='/Users/jcfq2/data/observations/jwst/saturn/PSG_files/psg_rad_3930_3959.txt'
bck_393_3958 = np.loadtxt(file_393_3958)


file_396_398='/Users/jcfq2/data/observations/jwst/saturn/PSG_files/psg_rad_396_398.txt'
bck_396_398 = np.loadtxt(file_396_398)


file_398_399='/Users/jcfq2/data/observations/jwst/saturn/PSG_files/psg_rad_398_399.txt'
bck_398_399 = np.loadtxt(file_398_399)


# Second-to-last cube: this geo object is the one used from here on
# (geo.convert and geo.im below).
file = files[-2]
# Create a JWSTSolarSystemPointing object that helps with lots of things JWST
geo = jssp.JWSTSolarSystemPointing(file)

# Calculate the geometry for the full IFU cube. This is a three-dimensional array.
# NOTE(review): cube is not used later in this file.
cube = geo.full_fov()

# Get the wavelength scale for this observation - should be the same for all NIRSpec G395H/F290L observations.
wave = geo.get_wavelength()

# this didn't work with PSG, and is outside the range used by CH4
# Channel-index windows (whw_*) on this cube's wavelength grid, one per PSG
# background file. Bounds are exclusive (strict > and <), in micrometres.

# this didn't work with PSG, and is outside the range used by CH4
wavemin_331 = 3.31
wavemax_331 = 3.34
whw_331 = np.argwhere((wave > wavemin_331) & (wave < wavemax_331)).flatten()


wavemin_345 = 3.445
wavemax_345 = 3.465
whw_345 = np.argwhere((wave > wavemin_345) & (wave < wavemax_345)).flatten()


wavemin_353 = 3.5255
wavemax_353 = 3.565
whw_353 = np.argwhere((wave > wavemin_353) & (wave < wavemax_353)).flatten()


wavemin_360 = 3.61
wavemax_360 = 3.635
whw_360 = np.argwhere((wave > wavemin_360) & (wave < wavemax_360)).flatten()


wavemin_366 = 3.662
wavemax_366 = 3.675
whw_366 = np.argwhere((wave > wavemin_366) & (wave < wavemax_366)).flatten()


# correction for 3.9008
# NOTE(review): this label repeats the one on the 3.90 window below — confirm
# it applies to this 3.71 window.

wavemin_371 = 3.709
wavemax_371 = 3.721
whw_371 = np.argwhere((wave > wavemin_371) & (wave < wavemax_371)).flatten()


# correction for 3.9008

wavemin_390 = 3.898
wavemax_390 = 3.911
whw_390 = np.argwhere((wave > wavemin_390) & (wave < wavemax_390)).flatten()


# correction for 3953
# NOTE(review): this window (3.923-3.9313) overlaps the next one (3.93-3.958),
# so channels in 3.93-3.9313 belong to both whw_392 and whw_3953.

wavemin_392 = 3.923
wavemax_392 = 3.9313
whw_392 = np.argwhere((wave > wavemin_392) & (wave < wavemax_392)).flatten()


# correction for 3953

wavemin_3953 = 3.93
wavemax_3953 = 3.958
whw_3953 = np.argwhere((wave > wavemin_3953) & (wave < wavemax_3953)).flatten()

# correction for 3953
wavemin_396 = 3.961
wavemax_398 = 3.979
whw_3966 = np.argwhere((wave > wavemin_396) & (wave < wavemax_398)).flatten()

# correction for 3953
wavemin_398 = 3.98
wavemax_399 = 3.991
whw_398 = np.argwhere((wave > wavemin_398) & (wave < wavemax_399)).flatten()

# to fit temperature

# wavemin = 3.3
# wavemax = 4.3



# whw = np.argwhere((wave > wavemin) & (wave < wavemax)).flatten()

# Reference H3+ model on the full wavelength grid (T=500, N=6e14, R=2700).
# This rebinds the module-level h3p_model.
h3p = h3ppy.h3p()
h3p.set(T=500, N=6e14, wave=wave, R=2700)
h3p_model=h3p.model()

# Median image of this cube (geo.im, channel axis first) over the 3.93-3.958 window.
im_h3p = np.nanmedian(geo.im[whw_3953, :, :], axis=0)
 
# NOTE(review): im_h3p and img are not used later in this file.
img = nam[:,:,1000]



# Create an interpolation function
# Linear interpolators for each PSG background (column 0 = wavelength,
# column 1 = radiance). With interp1d's default bounds handling, evaluating
# outside a file's wavelength range raises ValueError.
f_331 = interp1d(bck_331[:,0], bck_331[:,1], kind='linear')  # You can change 'linear' to other interpolation methods like 'cubic'
f_345 = interp1d(bck_345[:,0], bck_345[:,1], kind='linear')  # You can change 'linear' to other interpolation methods like 'cubic'
f_353 = interp1d(bck_353[:,0], bck_353[:,1], kind='linear')  # You can change 'linear' to other interpolation methods like 'cubic'
f_360 = interp1d(bck_360[:,0], bck_360[:,1], kind='linear')  # You can change 'linear' to other interpolation methods like 'cubic'
f_366 = interp1d(bck_366[:,0], bck_366[:,1], kind='linear')  # You can change 'linear' to other interpolation methods like 'cubic'
f_371 = interp1d(bck_371_372[:,0], bck_371_372[:,1], kind='linear')  # You can change 'linear' to other interpolation methods like 'cubic'
f_390 = interp1d(bck_390_391[:,0], bck_390_391[:,1], kind='linear')  # You can change 'linear' to other interpolation methods like 'cubic'
f_392 = interp1d(bck_392_3935[:,0], bck_392_3935[:,1], kind='linear')  # You can change 'linear' to other interpolation methods like 'cubic'
f_3953 = interp1d(bck_393_3958[:,0], bck_393_3958[:,1], kind='linear')  # You can change 'linear' to other interpolation methods like 'cubic'
f_3966 = interp1d(bck_396_398[:,0], bck_396_398[:,1], kind='linear')  # You can change 'linear' to other interpolation methods like 'cubic'
f_398 = interp1d(bck_398_399[:,0], bck_398_399[:,1], kind='linear')  # You can change 'linear' to other interpolation methods like 'cubic'

# Interpolate y values to the new x-range.
# NOTE(review): most of these are recomputed and rescaled inside the
# per-latitude loop. bck_331_whw uses the 3.31-3.34 window, which is
# widened further down.
bck_331_whw = f_331(wave[whw_331])
bck_345_whw = f_345(wave[whw_345])
bck_353_whw = f_353(wave[whw_353])
bck_360_whw = f_360(wave[whw_360])
bck_366_whw = f_366(wave[whw_366])
bck_371_372_whw = f_371(wave[whw_371])
bck_390_391_whw = f_390(wave[whw_390])
bck_392_3935_whw = f_392(wave[whw_392])
bck_393_3958_whw = f_3953(wave[whw_3953])
bck_396_398_whw = f_3966(wave[whw_3966])
bck_398_399_whw = f_398(wave[whw_398])





# %%

# sat_N=nam[:,:,682]/nac[:,:,682]

# One line colour per sampled latitude (80, 70, 60 N), in that order.
plotcol=np.array(['dodgerblue','limegreen','tomato'])

# Calibration coefficients used in the loop below:
#  - temp_ratio_factors: temperature as a quadratic in the 3.53/3.953 ratio;
#  - rq_int_ratio_factors: exp(cubic in temperature) gives the integrated
#    intensity that corresponds to a reference density of 1.5e16.
# Where these numbers come from is not recorded here.
temp_ratio_factors=np.array([198.58633939662195,227.15566799054042,37.28510863758654])
rq_int_ratio_factors = np.array([-31.121753993057936,0.1082841194633575,-0.0001632594032538115,9.08369715525734e-08])


# cheats in the earlier fits
# Widens the 3.31 window to 3.27-3.47 (overrides the earlier definition).
wavemin_331 = 3.27
wavemax_331 = 3.47
whw_331 = np.argwhere((wave > wavemin_331) & (wave < wavemax_331)).flatten()


# cheats in the earlier fits
# NOTE(review): whw_4 is only referenced in commented-out code.
wavemin_4 = 3.97
wavemax_4 = 3.99
whw_4 = np.argwhere((wave > wavemin_4) & (wave < wavemax_4)).flatten()



# 6x3 panel grid. Rows 0-1 are blanked to leave room for the polar map and
# its colorbar; rows 2-5 hold the spectra, one column per latitude.
fig2, axs = plt.subplots(ncols=3,nrows=6,figsize=(12,12.5),dpi=300, height_ratios=[1,0.25,1,1,1,1])

axs[0,0].axis('off')
axs[0,1].axis('off')
axs[0,2].axis('off')
axs[1,0].axis('off')
axs[1,1].axis('off')
axs[1,2].axis('off')
fig2.subplots_adjust(wspace=0)
# fig2.subplots_adjust(hspace=0,wspace=0)


# Polar-stereographic map as a new axes in the top fifth of fig2
# (presumably sitting over the blanked grid rows 0-1), limited to 50-90 N.
ax = plt.subplot(5,1,1,projection=ccrs.NorthPolarStereo(central_longitude=180))

# ax.set_extent([0, 360, 35, 90], ccrs.PlateCarree())

ax.set_extent([94+180, 86+180, 50, 90], ccrs.PlateCarree())
# Normalized map with a square-root (PowerNorm gamma 0.5) stretch, 0.1-1.2.
# NOTE(review): the 1.60804431421426 rescale factor is not explained here.
colorbarvalues = ax.imshow(map_h3p*1.60804431421426, origin='lower', extent=[-180,180,-90,90], transform=ccrs.PlateCarree(),cmap='hot',norm=colours.PowerNorm(0.5,vmin=0.1,vmax=1.2))
ax.gridlines()
# ax.text(310,50,'a')

# 1x1 degree boxes centred on 270 deg longitude at 80/70/60 N, in the same
# colours as the spectra plotted for those latitudes.
rect = patches.Rectangle((89.5+180,80-0.5), 1, 1, linewidth=2, edgecolor=plotcol[0],facecolor='none', transform=ccrs.PlateCarree())
ax.add_patch(rect)
rect = patches.Rectangle((89.5+180,70-0.5), 1, 1, linewidth=2, edgecolor=plotcol[1],facecolor='none', transform=ccrs.PlateCarree())
ax.add_patch(rect)
rect = patches.Rectangle((89.5+180,60-0.5), 1, 1, linewidth=2, edgecolor=plotcol[2],facecolor='none', transform=ccrs.PlateCarree())
ax.add_patch(rect)

latstr=np.array(['85','80','75','70','65','60','55'])


# Latitude labels laid out in axes-fraction coordinates: black for the
# first two, white for the rest.
for latlat in range(7):
    if latlat < 2: clat = 'black'
    else: clat='white'
    ax.annotate(latstr[latlat]+r'$^{\circ}$N',
            xy=(1, 1), xycoords='axes fraction',
            xytext=(0.096+0.125*latlat, 0.85), textcoords='axes fraction',c=clat
            )


# clat still holds 'white' from the last loop iteration.
ax.annotate(r'180$^{\circ}$W',
        xy=(1, 1), xycoords='axes fraction',
        xytext=(0.92, 0.49), textcoords='axes fraction',c=clat
        )


# Latitudes whose spectra are extracted and plotted in the loop below.
latitudes=np.array([80,70,60])


# Horizontal colorbar in figure coordinates, below the map.
cax = plt.axes([0.2, 0.73, 0.6, 0.025])
    # cax.set_ymargin(0.1)
cbar=plt.colorbar(colorbarvalues,cax=cax,label='Normalized auroral brightness',orientation='horizontal')


# Box style for the panel tags (I-IV) added after the loop.
bbox_props3_1 = dict(boxstyle="round,pad=0.3", fc="antiquewhite", ec="sandybrown", lw=2)


# For each sampled latitude (80, 70, 60 N, see `latitudes`), this loop:
#  - takes the map spectrum;
#  - removes scaled PSG backgrounds window by window;
#  - fits the non-LTE CH4 background;
#  - derives a first-guess H3+ temperature/density from a band ratio;
#  - fits H3+ with h3ppy;
#  - draws column i of panel rows 2-5.
# Values left by the last iteration (axs1..axs4, axs1b, spec_sub_steps,
# wavewhwmin/max, waveposmin/max, subwave_steps, subgroup) are reused by the
# annotation code after the loop.
for i in range(3):

    # Map cell: the second index is latitude + 90 and the first is fixed at
    # row 270, both in units of `scale` bins.
    xx=(90+latitudes[i])*scale#(left/right)
    yy=270*scale#(up/down)


    # xx=xxx+160*scale#(left/right)
    # yy=yyy+20#(up/down)
    # if planetmask[yy,xx] > 0:
    # NOTE(review): if geo.convert returns a view of spec_cube instead of a
    # copy, the in-place window subtractions below also modify spec_cube.
    spec = geo.convert(wave, spec_cube[yy, xx,:])  #auroral


    # Snapshot of the fitted windows before any background is removed.
    # hstack returns a new array, so later edits to spec do not reach it.
    spec_sub_pre=np.hstack([spec[whw_331],spec[whw_353],spec[whw_360],spec[whw_371],spec[whw_390],spec[whw_392],spec[whw_3953]])


    # 3.31 window: background subtraction deliberately disabled (no-op).
    spec[whw_331]=spec[whw_331]#-bck_331_whw


    # Each PSG window below is rescaled so an upper quantile over a few edge
    # channels matches the data there, then subtracted from spec in place.

    #3.53
    bck_353_whw = f_353(wave[whw_353])
    bck_353_whw=bck_353_whw/np.quantile(bck_353_whw[-10:],0.75)*(np.quantile(spec[whw_353][-10:],0.75))
    # bck_353_whw=bck_353_whw/np.quantile(bck_353[:,1],0.85)*np.quantile(spec[whw_353],0.85)
    spec[whw_353]=spec[whw_353]-bck_353_whw


    #3.61
    bck_360_whw = f_360(wave[whw_360])
    bck_360_whw=bck_360_whw/np.quantile(bck_360_whw[-10:],0.75)*(np.quantile(spec[whw_360][-10:],0.75))
    # bck_360_whw=bck_360_whw/np.quantile(bck_360[:,1],0.85)*np.quantile(spec[whw_360],0.85)
    spec[whw_360]=spec[whw_360]-bck_360_whw

    #3.66: scaled to the last-9-channel level minus the first-5-channel median.
    # This window is subtracted but not included in spec_sub below.
    bck_366_whw = f_366(wave[whw_366])
    bck_366_whw=bck_366_whw/np.quantile(bck_366_whw[-9:],0.75)*(np.quantile(spec[whw_366][-9:],0.75)-np.quantile(spec[whw_366][:5],0.5))
    # bck_366_whw=bck_366_whw/np.quantile(bck_366[:,1],0.85)*np.quantile(spec[whw_366],0.85)
    spec[whw_366]=spec[whw_366]-bck_366_whw


    #3.71
    bck_371_372_whw = f_371(wave[whw_371])
    bck_371_372_whw=bck_371_372_whw/np.quantile(bck_371_372_whw[-5:],0.75)*(np.quantile(spec[whw_371][-5:],0.75))
    # bck_371_372_whw=bck_371_372_whw/np.quantile(bck_371_372[:,1],0.5)*np.quantile(spec[whw_388],0.55)
    spec[whw_371]=spec[whw_371]-bck_371_372_whw

    # # lets make some background for h3ppy to fit no h3p to
    # whw_null=np.arange(500)
    # wave_null=wave[whw_null]
    # spec_null = np.zeros_like(wave_null)+np.random.normal(0, 1e-7, wave_null.shape)

    # # lets make some background for h3ppy to fit no h3p to
    # whw_null2=np.arange(150)+1300
    # wave_null2=wave[whw_null2]
    # spec_null2 = np.zeros_like(wave_null2)+np.random.normal(0, 1e-7, wave_null2.shape)


    #3.90: scaled to the first-6-channel level minus the last-4-channel median.
    bck_390_391_whw = f_390(wave[whw_390])
    bck_390_391_whw=bck_390_391_whw/np.quantile(bck_390_391_whw[0:6],0.75)*(np.quantile(spec[whw_390][0:6],0.75)-np.quantile(spec[whw_390][-4:],0.5))
    spec[whw_390]=spec[whw_390]-bck_390_391_whw


    #3.92: normalised by the raw PSG file column rather than the interpolated window.
    bck_392_3935_whw = f_392(wave[whw_392])
    bck_392_3935_whw=bck_392_3935_whw/np.quantile(bck_392_3935[:,1],0.75)*np.quantile(spec[whw_392][0:6],0.75)
    spec[whw_392]=spec[whw_392]-bck_392_3935_whw


    #3.953
    # NOTE(review): whw_392 (3.923-3.9313) and whw_3953 (3.93-3.958) overlap,
    # so channels in 3.93-3.9313 get both backgrounds subtracted and appear
    # twice in spec_sub/whw_sub below — confirm this is intended.
    bck_393_3958_whw = f_3953(wave[whw_3953])
    bck_393_3958_whw=bck_393_3958_whw/np.quantile(bck_393_3958[:,1],0.75)*np.quantile(spec[whw_3953],0.75)
    spec[whw_3953]=spec[whw_3953]-bck_393_3958_whw



    # spec_sub_back=np.hstack([spec[whw_331]*0,bck_353_whw,bck_360_whw,bck_371_372_whw,bck_390_391_whw,bck_392_3935_whw,bck_393_3958_whw])
    # why nospec[whw_366] was it bad????

    # not used - just impossible to pull out a signal from the background
    # #3.966
    # bck_396_398_whw = f_3966(wave[whw_3966])
    # bck_396_398_whw=bck_396_398_whw/np.quantile(bck_396_398[:,1],0.75)*np.quantile(spec[whw_3966][0:10],0.75)
    # spec[whw_3966]=spec[whw_3966]-bck_396_398_whw


    # not used -- too close to brightness edge to be trustable
    #3.98
    # bck_398_399_whw = f_398(wave[whw_398])
    # bck_398_399_whw=bck_398_399_whw/np.quantile(bck_398_399[:,1],0.75)*np.quantile(spec[whw_398],0.75)
    # spec[whw_398]=spec[whw_398]-bck_398_399_whw


    # spec_sub=np.hstack([spec_null,spec[whw_353],spec[whw_360],spec[whw_371],spec_null2,spec[whw_390],spec[whw_392],spec[whw_3953]])
    # whw_sub=np.hstack([whw_null,whw_353,whw_360,whw_371,whw_null2,whw_390,whw_392,whw_3953])

    # #  i used this to finish the map

    wavepos=np.arange(len(spec))

    # Concatenated fitted windows: background-subtracted data, channel
    # indices, the subtracted backgrounds (zeros for 3.31) and positions.
    spec_sub=np.hstack([spec[whw_331],spec[whw_353],spec[whw_360],spec[whw_371],spec[whw_390],spec[whw_392],spec[whw_3953]])
    whw_sub=np.hstack([whw_331,whw_353,whw_360,whw_371,whw_390,whw_392,whw_3953])
    spec_sub_back=np.hstack([spec[whw_331]*0,bck_353_whw,bck_360_whw,bck_371_372_whw,bck_390_391_whw,bck_392_3935_whw,bck_393_3958_whw])

    spec_sub_pos=np.hstack([wavepos[whw_331],wavepos[whw_353],wavepos[whw_360],wavepos[whw_371],wavepos[whw_390],wavepos[whw_392],wavepos[whw_3953]])

    # Window boundaries in spec_sub: positions where the channel index jumps
    # by more than 2. With np.roll, entry 0 is compared with the last
    # channel, whose delta is negative, so it is not counted.
    spec_sub_pos_delta=spec_sub_pos-np.roll(spec_sub_pos,1)

    spec_sub_steps=np.argwhere(spec_sub_pos_delta > 2)


    wave_sub=wave[whw_sub]

    # Same idea in wavelength: jumps larger than 1e-3. Entry 0 is compared
    # with the last wavelength, so it counts as the first window start.
    spec_wave_delta=wave_sub-np.roll(wave_sub,1)

    spec_wave_steps=np.argwhere(np.abs(spec_wave_delta) > 1e-3)


    # Wavelength edges of the first six windows (used to shade row 2).
    # spec_wave_steps is an (n, 1) argwhere result, so index column 0 to get
    # a plain scalar. This gives the value the former
    # wave[np.argwhere(wave == ...)] lookup produced, without NumPy's
    # deprecated size-1-array-to-scalar assignment.
    wavewhwmin=np.zeros([6])
    wavewhwmax=np.zeros([6])

    for iw in range(6):
        wavewhwmax[iw]=wave_sub[spec_wave_steps[iw, 0]-1]
        wavewhwmin[iw]=wave_sub[spec_wave_steps[iw, 0]]

    # Shift so wavewhwmax[iw] is the end of the window starting at
    # wavewhwmin[iw]. The last entry wraps to wave_sub[-1].
    wavewhwmax=np.roll(wavewhwmax,-1)


    # spec_sub=np.hstack([spec[whw_331],spec[whw_353],spec[whw_360],spec[whw_371],spec[whw_390],spec[whw_392],spec[whw_3953],spec[whw_4]])
    # whw_sub=np.hstack([whw_331,whw_353,whw_360,whw_371,whw_390,whw_392,whw_3953,whw_4])


    # I used this for the majority of the map
    # spec_sub=np.hstack([spec[whw_353],spec[whw_360],spec[whw_371],spec[whw_390],spec[whw_392],spec[whw_3953]])
    # whw_sub=np.hstack([whw_353,whw_360,whw_371,whw_390,whw_392,whw_3953])

    # spec_sub=np.hstack([spec[whw_353],spec[whw_360],spec[whw_371],spec[whw_390]])
    # whw_sub=np.hstack([whw_353,whw_360,whw_371,whw_390])


    # Reload so edits to ch4_fiddlesticks take effect without restarting.
    reload(ch4)

    # Non-LTE CH4 background fit over the full background-subtracted spectrum.
    ch4fit = ch4.fit_non_LTE_CH4_JWST(wave)
    fit = ch4fit.fit(wave, spec)
    # fit[whw_null]=0.
    # fit[whw_null2]=0.


    fit_sub = fit[whw_sub]

    # Fit restricted to the two windows used for the line ratio.
    fit_subA = fit[whw_353]
    fit_subB = fit[whw_3953]



    # wavewhw_delta=wave_sub-np.roll(wave_sub,1)
    # wavewhw_steps=np.argwhere(np.abs(wavewhw_delta) > 1e-3)

    # Keep only samples where the CH4 fit is finite. finite_pos maps each
    # kept sample back to its position in wave_sub / spec_sub.
    finite_sub = np.isfinite(fit_sub)
    finite_pos = np.flatnonzero(finite_sub)
    subwave = wave_sub[finite_sub]
    subpos = spec_sub_pos[finite_sub]
    subwave_delta=subwave-np.roll(subwave,1)
    subwave_steps=np.argwhere(np.abs(subwave_delta) > 1e-3)


    # Residual after removing the CH4 background (this is what h3ppy fits).
    subspec = (spec_sub - fit_sub)[finite_sub]

    subback = fit_sub[finite_sub]

    subspecA = (spec[whw_353]-fit_subA)[np.isfinite(fit_subA)]
    subspecB = (spec[whw_3953]-fit_subB)[np.isfinite(fit_subB)]

    # First-guess temperature from the 3.53 / 3.953 residual ratio
    # (quadratic calibration), clamped to 350-550.
    rq_ratio = np.sum(subspecA)/np.sum(subspecB)
    rq_temp=temp_ratio_factors[0]+temp_ratio_factors[1]*rq_ratio+temp_ratio_factors[2]*rq_ratio**2

    if rq_temp < 350: rq_temp=350
    if rq_temp > 550: rq_temp=550

    # First-guess density: summed residual divided by the intensity that the
    # cubic calibration gives at rq_temp, times the reference 1.5e16.
    rq_int = np.sum(subspecA)+np.sum(subspecB)
    rq_scaling=np.exp(rq_int_ratio_factors[0]+rq_int_ratio_factors[1]*rq_temp+rq_int_ratio_factors[2]*rq_temp**2+rq_int_ratio_factors[3]*rq_temp**3)

    rq_density=rq_int/rq_scaling*1.5e16

    # Index into wave_sub of the first/last finite sample around each
    # wavelength jump in subwave (used to shade row 4).
    # finite_pos gives the position directly. The former
    # np.argwhere(wave_sub == value) search relied on a deprecated
    # array-to-scalar assignment, and it raised wherever the overlapping
    # 3.92/3.953 windows put the same wavelength into wave_sub twice.
    waveposmin=np.zeros([9])
    waveposmax=np.zeros([9])

    for iw in range(9):
        waveposmax[iw]=finite_pos[subwave_steps[iw, 0]-1]
        waveposmin[iw]=finite_pos[subwave_steps[iw, 0]]

    waveposmax=np.roll(waveposmax,-1)




    print(rq_ratio,rq_temp,rq_int,np.sum(spec[whw_353])/np.sum(spec[whw_3953]),rq_scaling,rq_density/1e15)
    # rq_temp=450
    # rq_density=rq_density*30
    # Model at the ratio-based first guess (before fitting).
    h3p = h3ppy.h3p()
    h3p.set(wave=subwave, data=subspec)
    h3p.set(T=rq_temp, N=rq_density)
    prefit_model = h3p.model()

    # pre_temperature[yy,xx_x]=rq_temp
    # pre_density[yy,xx_x]=rq_density
    # pre_totalE[yy,xx_x]=h3p.total_emission()

    ratio_mask = np.zeros_like(prefit_model)+1.

    # if (np.sum(subspec[prefit_model > 0.1])/np.sum(ratio_mask[prefit_model > 0.1])) / (np.sum(subspec[prefit_model < 0.1])/np.sum(ratio_mask[prefit_model < 0.1])) > 10:
    # (The intensity / CH4 / total-emission sums used to be computed here as
    # well, but those values were overwritten right after the fit below.)

    # Fit the residual H3+ spectrum, starting from the first guess.
    h3p = h3ppy.h3p()
    h3p.set(wave=subwave, data=subspec)
    h3p.set(T=rq_temp, N=rq_density)
    # h3p.guess_density()
    h3p_fit = h3p.fit()
    h3p_model=h3p.model()

    intensity=np.sum(subspec)
    ch4_fun=np.nansum(ch4fit.ch4_fun[whw_sub]*ch4fit.fun_scaling)
    ch4_hot=np.nansum(ch4fit.ch4_hot[whw_sub]*ch4fit.hot_scaling)
    totalE=h3p.total_emission()

    # Named fit_vars so the builtin vars() is not shadowed in the session.
    fit_vars, fit_errs = h3p.get_results(verbose=False)

    # if fit_vars == False:
    #     print(xx_x,xx/scale,yy/scale,' -Fit failed- ',round(pre_temperature[yy,xx_x], 3),', ',round(pre_density[yy,xx_x]/1e15, 3),', ',round(pre_totalE[yy,xx_x]/1e-6, 3) )
    # else:
    #     temperature[yy,xx_x]=fit_vars['temperature']
    #     density[yy,xx_x]=fit_vars['density']
    #     temperature_error[yy,xx_x]=fit_errs['temperature']
    #     density_error[yy,xx_x]=fit_errs['density']
    #     totalE[yy,xx_x]=h3p.total_emission()

    #     print( xx/scale,', ',yy/scale,', ',round(temperature[yy,xx_x], 3),', ',round(density[yy,xx_x]/1e15, 3),', ',round(totalE[yy,xx_x]/1e-6, 3) )


    # Row 2: full raw spectrum vs wavelength. Raw strings keep the TeX
    # backslashes without invalid-escape warnings.
    axs[1+1,i].plot(wave,spec,label='raw data at '+str(latitudes[i])+r'$^{\circ}$N',c=plotcol[i],linewidth=0.5,drawstyle='steps-mid')
    axs[1+1,i].set_xticks([3.0,3.5,4.0,4.5,5.0],[r'3.0$\mu$m',r'3.5$\mu$m',r'4.0$\mu$m',r'4.5$\mu$m',r'5.0$\mu$m'])

    # Row 3: pre-subtraction windows and the PSG model, from sample 300 on.
    xfix=np.arange(spec_sub_pre[300:].size)+300
    axs[2+1,i].plot(xfix,spec_sub_pre[300:],c=plotcol[i],drawstyle='steps-mid')
    axs[2+1,i].plot(xfix,spec_sub_back[300:],label='PSG model',linestyle='--',c='darkred')
    # axs[2+1,i].legend()
    # axs[2+1,i].set_yticks([])


    # Row 4: windows vs the CH4 components and the total background fit.
    axs[3+1,i].plot(spec_sub,c=plotcol[i])
    axs[3+1,i].plot(ch4fit.ch4_fun[whw_sub]*ch4fit.fun_scaling,label='methane_fund',linestyle='dotted')
    axs[3+1,i].plot(ch4fit.ch4_hot[whw_sub]*ch4fit.hot_scaling,label='methane_hot',linestyle='dotted')

    axs[3+1,i].plot(fit_sub,label='back fit',linestyle='--',c='xkcd:vivid purple',drawstyle='steps-mid')
    # axs[3+1,i].legend()

    # Row 5: CH4-subtracted residual and the H3+ fit.
    axs[4+1,i].plot(subspec,c=plotcol[i])
    axs[4+1,i].plot(h3p_model,label='H3+ fit',linestyle='--',c='xkcd:purpley blue',drawstyle='steps-mid')


    # axs[4+1,i].legend()
    print(i,np.sum(subspec))
    # axs[3+1,1].set_yticks([])
    axs[1+1,i].legend()
    if i > 0:
        axs[1+1,i].set_yticks([])
        # axs[1+1,i].sharey(axs[1+1,0])

    # Only the first column keeps y ticks; it carries the legends instead.
    for j in range(3):
        if i > 0:
            axs[j+3,i].set_yticks([])
            # axs[j+2,i].sharey(axs[j+2,0])
        else:
            axs[j+3,0].legend(loc='upper center')

    axs1=axs[1+1,2].get_ylim()
    # this gets broken later in an inexplicable way, only if you run:
        # axs[1+1,2].fill([wavewhwmin[iw],wavewhwmin[iw],wavewhwmax[iw],wavewhwmax[iw]] , [axs1[0],axs1[1],axs1[1],axs1[0] ] ,c='whitesmoke')
    # so to fix, I'm hardwiring the values here:
    axs1b= (-0.0008817610230538503,0.01740230019661231)

    # Share each row's y-range across the three columns.
    axs[1+1,0].set_ylim(axs[1+1,2].get_ylim())
    axs[1+1,1].set_ylim(axs[1+1,2].get_ylim())


    axs2=axs[2+1,2].get_ylim()

    axs[2+1,0].set_ylim(axs[2+1,2].get_ylim())
    axs[2+1,1].set_ylim(axs[2+1,2].get_ylim())


    axs3=axs[3+1,0].get_ylim()

    axs[3+1,0].set_ylim(axs3)
    axs[3+1,1].set_ylim(axs3)
    axs[3+1,2].set_ylim(axs3)

    axs4=axs[4+1,0].get_ylim()

    axs[4+1,1].set_ylim(axs[4+1,0].get_ylim())
    axs[4+1,2].set_ylim(axs[4+1,0].get_ylim())

    axs3=axs[3+1,0].get_ylim()

    subgroup=np.array(['1','2','3','4','5','6'])

    # Dotted separators at the window boundaries on rows 3 and 4.
    for k in range(len(spec_sub_steps)):
        axs[2+1,i].plot([spec_sub_steps[k],spec_sub_steps[k]],[axs2[0],axs2[1]],c='grey',linestyle='dotted')


    for k in range(len(spec_sub_steps)):
        axs[3+1,i].plot([spec_sub_steps[k],spec_sub_steps[k]],[axs3[0],axs3[1]],c='grey',linestyle='dotted')

    # axs[1+1,0].fill(x,y)
# Number windows 2-5 on the row-3 panels at the midpoint between consecutive
# channel-index jumps, and window 6 just past the last listed jump.
# NOTE(review): spec_sub_steps and subgroup come from the last loop
# iteration. spec_sub_steps is an (n, 1) argwhere result, so each x here is a
# 1-element array rather than a scalar.
for k in range(4): 
    axs[2+1,0].text((spec_sub_steps[k]+spec_sub_steps[k+1])/2 , 0.2 ,subgroup[k+1],c='grey',zorder=1,va='center',ha='center')
    axs[2+1,1].text((spec_sub_steps[k]+spec_sub_steps[k+1])/2 , 0.3 ,subgroup[k+1],c='grey',zorder=1,va='center',ha='center')
    axs[2+1,2].text((spec_sub_steps[k]+spec_sub_steps[k+1])/2 , 0.3 ,subgroup[k+1],c='grey',zorder=1,va='center',ha='center')

# k is 3 after the loop above, so subgroup[k+2] is '6'.
axs[2+1,0].text(spec_sub_steps[4]+30 , 0.2 ,subgroup[k+2],c='grey',zorder=1,va='center',ha='center')
axs[2+1,1].text(spec_sub_steps[4]+30 , 0.3 ,subgroup[k+2],c='grey',zorder=1,va='center',ha='center')
axs[2+1,2].text(spec_sub_steps[4]+30 , 0.3 ,subgroup[k+2],c='grey',zorder=1,va='center',ha='center')


# Debug print of the row-2 y-range captured in the loop.
print(axs1[0],axs1[1])
# Shade the six fitted wavelength windows on every row-2 panel (behind the
# data, zorder 0) and number them '1'-'6' at y=10 in data units.
for iw in range(6): 
    axs[1+1,0].fill([wavewhwmin[iw],wavewhwmin[iw],wavewhwmax[iw],wavewhwmax[iw]] , [axs1[0],axs1[1],axs1[1],axs1[0] ] ,c='whitesmoke',zorder=0)
    axs[1+1,1].fill([wavewhwmin[iw],wavewhwmin[iw],wavewhwmax[iw],wavewhwmax[iw]] , [axs1[0],axs1[1],axs1[1],axs1[0] ] ,c='whitesmoke',zorder=0)
    axs[1+1,2].fill([wavewhwmin[iw],wavewhwmin[iw],wavewhwmax[iw],wavewhwmax[iw]] , [axs1[0],axs1[1],axs1[1],axs1[0] ] ,c='whitesmoke',zorder=0)
 
    axs[1+1,0].text((wavewhwmin[iw]+wavewhwmax[iw])/2,10,subgroup[iw],c='grey',va='center',ha='center')
    axs[1+1,1].text((wavewhwmin[iw]+wavewhwmax[iw])/2,10,subgroup[iw],c='grey',va='center',ha='center')
    axs[1+1,2].text((wavewhwmin[iw]+wavewhwmax[iw])/2,10,subgroup[iw],c='grey',va='center',ha='center')
 
 

# axs[1+1,1].plot([(wavewhwmax[4]+wavewhwmin[5])/2,(wavewhwmax[4]+wavewhwmin[5])/2] , [axs1[0],axs1[1]] ,c='white',linewidth=0.5,zorder=1)
# axs[1+1,2].plot([(wavewhwmax[4]+wavewhwmin[5])/2,(wavewhwmax[4]+wavewhwmin[5])/2] , [axs1b[0],axs1b[1]] ,c='white',linewidth=0.5,zorder=1)

# Print again to check whether the fills above changed axs1 (see the
# "gets broken later" note in the loop).
print(axs1[0],axs1[1])

    # axs[1+1,0].fill([np.min(wave[whw_331]),np.min(wave[whw_331]),np.max(wave[whw_331]),np.max(wave[whw_331])] , [axs1[0],axs1[1],axs1[1],axs1[0] ] ,c='whitesmoke')
# axs[1+1,0].plot( [np.min(wave[whw_331]),np.min(wave[whw_331])], [axs1[0],axs1[1],axs1[1],axs[0] ] ,c='grey',linestyle='dotted' )
# axs[1+1,0].plot( [np.max(wave[whw_331]),np.max(wave[whw_331])], [axs[1+1,2].get_ylim()[0],axs[1+1,2].get_ylim()[1]] ,c='grey',linestyle='dotted' )

# Letters for the finite-sample segments of rows 4 and 5. 'b' is absent and
# subgroup[1] ('c') is unused here, because segment 1 is labelled 'b'/'c'
# separately further down, split at x=195 (row 4) and x=75 (row 5).
subgroup=np.array(['a','c','d','e','f','g','h','i','j','k','l'])
# Row 4: shade the nine segments between waveposmin/waveposmax (wave_sub
# positions).
for iw in range(9): 
    axs[3+1,0].fill([waveposmin[iw],waveposmin[iw],waveposmax[iw],waveposmax[iw]] , [axs3[0],axs3[1],axs3[1],axs3[0] ] ,c='lavenderblush')
    axs[3+1,1].fill([waveposmin[iw],waveposmin[iw],waveposmax[iw],waveposmax[iw]] , [axs3[0],axs3[1],axs3[1],axs3[0] ] ,c='lavenderblush')
    axs[3+1,2].fill([waveposmin[iw],waveposmin[iw],waveposmax[iw],waveposmax[iw]] , [axs3[0],axs3[1],axs3[1],axs3[0] ] ,c='lavenderblush')

# Segment 0 ('a') on rows 4 and 5. Row 5 positions are midpoints between
# subwave_steps (indices into the finite-only subspec).
for iw in range(1): 
    axs[3+1,0].text((waveposmin[iw]+waveposmax[iw])/2.,-0.05,subgroup[iw],c='plum',va='center',ha='center')
    axs[3+1,1].text((waveposmin[iw]+waveposmax[iw])/2.,-0.05,subgroup[iw],c='plum',va='center',ha='center')
    axs[3+1,2].text((waveposmin[iw]+waveposmax[iw])/2.,-0.05,subgroup[iw],c='plum',va='center',ha='center')
    axs[4+1,0].text((subwave_steps[iw]+subwave_steps[iw+1])/2,0.15,subgroup[iw],c='plum',va='center',ha='center')
    axs[4+1,1].text((subwave_steps[iw]+subwave_steps[iw+1])/2,0.20,subgroup[iw],c='plum',va='center',ha='center')
    axs[4+1,2].text((subwave_steps[iw]+subwave_steps[iw+1])/2,0.20,subgroup[iw],c='plum',va='center',ha='center')

# Segments 2-8 ('d'-'j') on row 4.
for iw in range(7): 
    axs[3+1,0].text((waveposmin[iw+2]+waveposmax[iw+2])/2.,-0.05,subgroup[iw+2],c='plum',va='center',ha='center')
    axs[3+1,1].text((waveposmin[iw+2]+waveposmax[iw+2])/2.,-0.05,subgroup[iw+2],c='plum',va='center',ha='center')
    axs[3+1,2].text((waveposmin[iw+2]+waveposmax[iw+2])/2.,-0.05,subgroup[iw+2],c='plum',va='center',ha='center')
# Segments 2-7 ('d'-'i') on row 5; 'j' is placed separately further down.
for iw in range(6): 
    axs[4+1,0].text((subwave_steps[iw+2]+subwave_steps[iw+2+1])/2,0.15,subgroup[iw+2],c='plum',va='center',ha='center')
    axs[4+1,1].text((subwave_steps[iw+2]+subwave_steps[iw+2+1])/2,0.20,subgroup[iw+2],c='plum',va='center',ha='center')
    axs[4+1,2].text((subwave_steps[iw+2]+subwave_steps[iw+2+1])/2,0.20,subgroup[iw+2],c='plum',va='center',ha='center')


# split at 75, and at 195


# Axis labels ('radiance' spelling fixed in the y-label).
axs[4+1,1].set_xlabel('Wavelength bins')
axs[2+1,0].set_ylabel(r'Spectral radiance [mW/m$^2$/$\mu$m/sr]')


# Segment 1 of row 4 is split at x=195 into 'b' and 'c'.
axs[3+1,0].text((waveposmin[1]+195)/2.,-0.05,'b',c='plum',va='center',ha='center')
axs[3+1,0].text((waveposmax[1]+195)/2.,-0.05,'c',c='plum',va='center',ha='center')
axs[3+1,1].text((waveposmin[1]+195)/2.,-0.05,'b',c='plum',va='center',ha='center')
axs[3+1,1].text((waveposmax[1]+195)/2.,-0.05,'c',c='plum',va='center',ha='center')
axs[3+1,2].text((waveposmin[1]+195)/2.,-0.05,'b',c='plum',va='center',ha='center')
axs[3+1,2].text((waveposmax[1]+195)/2.,-0.05,'c',c='plum',va='center',ha='center')


# White dashed split line at x=195, spanning the row-4 panels' own y-range.
# axs3 is the row-4 y-limit fixed in the loop. These lines previously used
# the row-2 limits (axs1 / hard-wired axs1b), which left only a short stub
# on the third panel.
axs[3+1,0].plot([195,195] , [axs3[0],axs3[1]] ,c='white',linewidth=1,zorder=1,linestyle='--')
axs[3+1,1].plot([195,195] , [axs3[0],axs3[1]] ,c='white',linewidth=1,zorder=1,linestyle='--')
axs[3+1,2].plot([195,195] , [axs3[0],axs3[1]] ,c='white',linewidth=1,zorder=1,linestyle='--')


# Row 5: dotted separators at the interior wavelength jumps of subwave, plus
# an extra split at x=75 inside segment 1.
for iw in range(8): 

    axs[4+1,0].plot([subwave_steps[iw+1],subwave_steps[iw+1]] , [axs4[0],axs4[1]] ,c='plum',linestyle='dotted')
    axs[4+1,1].plot([subwave_steps[iw+1],subwave_steps[iw+1]] , [axs4[0],axs4[1]] ,c='plum',linestyle='dotted')
    axs[4+1,2].plot([subwave_steps[iw+1],subwave_steps[iw+1]] , [axs4[0],axs4[1]] ,c='plum',linestyle='dotted')
axs[4+1,0].plot([75,75] , [axs4[0],axs4[1]] ,c='plum',linestyle='dotted')
axs[4+1,1].plot([75,75] , [axs4[0],axs4[1]] ,c='plum',linestyle='dotted')
axs[4+1,2].plot([75,75] , [axs4[0],axs4[1]] ,c='plum',linestyle='dotted')


# split at 75: label the two halves of segment 1 'b' and 'c'.
axs[4+1,0].text((subwave_steps[1]+75)/2,0.15,'b',c='plum',va='center',ha='center')
axs[4+1,0].text((subwave_steps[2]+75)/2,0.15,'c',c='plum',va='center',ha='center')
axs[4+1,1].text((subwave_steps[1]+75)/2,0.20,'b',c='plum',va='center',ha='center')
axs[4+1,1].text((subwave_steps[2]+75)/2,0.20,'c',c='plum',va='center',ha='center')
axs[4+1,2].text((subwave_steps[1]+75)/2,0.20,'b',c='plum',va='center',ha='center')
axs[4+1,2].text((subwave_steps[2]+75)/2,0.20,'c',c='plum',va='center',ha='center')
# axs[4+1,1].text(subwave_steps[8]+5,0.20,subgroup[8],c='plum',va='center',ha='center')
# axs[4+1,2].text(subwave_steps[8]+5,0.20,subgroup[8],c='plum',va='center',ha='center')


# Last row-5 segment: subgroup[8] is 'j', placed just right of the final jump.
axs[4+1,0].text(subwave_steps[8]+5,0.15,subgroup[8],c='plum',va='center',ha='center')
axs[4+1,1].text(subwave_steps[8]+5,0.20,subgroup[8],c='plum',va='center',ha='center')
axs[4+1,2].text(subwave_steps[8]+5,0.20,subgroup[8],c='plum',va='center',ha='center')

# Roman-numeral tags on the first column of each spectra row, at hand-placed
# data coordinates.
axs[1+1,0].text(2.9,14.8,'I',bbox=bbox_props3_1,size='large')
axs[2+1,0].text(300,0.29,'II',bbox=bbox_props3_1,size='large')
axs[3+1,0].text(0,0.19,'III',bbox=bbox_props3_1,size='large')
axs[4+1,0].text(0,0.19,'IV',bbox=bbox_props3_1,size='large')

# Written relative to the working directory (jwst_dir after the chdir at the top).
fig2.savefig('asd_temper_fig1.pdf', dpi=300, bbox_inches='tight', facecolor='white', pad_inches=0) 
# fig2.savefig('asd_temper_fig1.pdf', dpi=300, facecolor='white') 

plt.show()






# %%

# fig2.savefig('asd_fig2b.pdf', dpi=300, bbox_inches='tight', facecolor='white', pad_inches=0) 

# plt.show()



# %%


# ax = plt.subplot(242,projection=ccrs.NorthPolarStereo(central_longitude=180))
# # ax.set_extent([0, 360, 68, 90], ccrs.PlateCarree())
# # ax.imshow(diff_h3p, origin='lower', extent=[-180,180,-90,90], transform=ccrs.PlateCarree(),cmap='seismic',vmax=np.nanmax(diff_h3p),vmin=(-1)*np.nanmax(diff_h3p))
# ax.gridlines()
# # ax.text(310,50,'a')
# ax.plot([100,100],[0,50], transform=ccrs.PlateCarree())

